#include "gosdt.hpp"

// File-local build switches. They are defined after the include above, so they cannot
// influence gosdt.hpp, and nothing in the code visible in this file reads either of them.
// NOTE(review): identifiers that start with an underscore followed by an uppercase letter
// (such as _DEBUG) are reserved for the implementation in C++ - consider renaming once all
// users of the macro are known.
#define _DEBUG true
#define THROTTLE false

// Run statistics shared by every GOSDT instance (static members).
//   time       : wall-clock duration of the most recent worker search, in seconds
//   size       : value reported by Optimizer::size() after the most recent search
//   iterations : sum of the per-worker iteration counts of the most recent search
//   status     : 0 after a reset; set to 1 when non-convergence, false convergence or an
//                IntegrityViolation is detected
float GOSDT::time = 0.0f;
unsigned int GOSDT::size = 0u;
unsigned int GOSDT::iterations = 0u;
unsigned int GOSDT::status = 0u;

// Constructs a GOSDT driver; there is nothing to initialise here.
GOSDT::GOSDT() {}

// Destroys a GOSDT driver; no explicit clean-up is performed here.
GOSDT::~GOSDT() {}

// Forwards the configuration stream unchanged to Configuration::configure.
void GOSDT::configure(std::istream & config_source) {
    Configuration::configure(config_source);
}


// Convenience overload: trains on `data_source` (see the overload taking a model set) and
// stores a JSON document in `result`.
//
// NOTE(review): per-model serialization is disabled in this overload, so `result` is always
// the dump of an empty JSON array, however many models the training run produced. The loop
// that used to walk the model set only copied each Model and threw the copy away; that dead
// work has been removed. To re-enable the output, serialize each model with
// Model::to_json(json &) and push the object onto `output` before dumping.
//
// data_source : stream forwarded to the training overload
// result      : overwritten with the pretty-printed (indent 2) JSON document
void GOSDT::fit(std::istream & data_source, std::string & result) {
    std::unordered_set< Model > models;
    fit(data_source, models);

    json const output = json::array();
    result = output.dump(2);
}

// Trains on the dataset read from `data_source` and collects the extracted models in `models`.
//
// Steps, in order:
//   1. Load the dataset into a fresh Optimizer.
//   2. If Configuration::datatset_encoding is non-empty, write the binary-feature encoding
//      metadata to that path as JSON and return without training.
//   3. Reset the GOSDT run statistics, run Configuration::worker_limit worker threads to
//      completion, then record duration, iteration count and graph size. The duration is
//      appended to the file named by Configuration::timing when that is non-empty.
//   4. Extract models through optimizer.models(models). GOSDT::status becomes 1 when the
//      optimizer reports it is not complete, or when no model came back although
//      Configuration::model_limit > 0.
//   5. If Configuration::model is non-empty, write the models to that path as a JSON array.
//   6. If Configuration::rashomon is set, take loss + complexity of the first extracted model
//      as the base bound, reset the optimizer (keeping the dataset) and hand over to the
//      Rashomon overload of fit together with the same `models` set.
//
// data_source : stream handed to Optimizer::load
// models      : set passed to Optimizer::models to receive the extracted models
void GOSDT::fit(std::istream & data_source, std::unordered_set< Model > & models) {
    if(Configuration::verbose) { std::cout << "Using configuration: " << Configuration::to_string(2) << std::endl; }

    if(Configuration::verbose) { std::cout << "Initializing Optimization Framework" << std::endl; }
    Optimizer optimizer;
    optimizer.load(data_source);

    // Dump dataset metadata if requested and terminate early
    if (Configuration::datatset_encoding != "") {
        json output = json::array();
        for (int binary_feature_index=0; binary_feature_index<State::dataset.encoder.binary_features(); binary_feature_index++) {
            json node = json::object();
            unsigned int feature_index;
            std::string feature_name, feature_type, relation, reference;
            State::dataset.encoder.decode(binary_feature_index, & feature_index);
            State::dataset.encoder.encoding(binary_feature_index, feature_type, relation, reference);
            State::dataset.encoder.header(feature_index, feature_name);

            node["feature"] = feature_index;
            node["name"] = feature_name;
            node["relation"] = relation;
            // The reported type is derived from the textual form of the reference value.
            if (Encoder::test_integral(reference)) {
                node["type"] = "integral";
                node["reference"] = atoi(reference.c_str());
            } else if (Encoder::test_rational(reference)) {
                node["type"] = "rational";
                node["reference"] = atof(reference.c_str());
            } else {
                node["type"] = "categorical";
                node["reference"] = reference;
            }
            output.push_back(node);
        }
        std::string result = output.dump(2);
        if(Configuration::verbose) { std::cout << "Storing Metadata in: " << Configuration::datatset_encoding << std::endl; }
        std::ofstream out(Configuration::datatset_encoding);
        out << result;
        out.close();
        return;
    }

    // Start from clean run statistics.
    GOSDT::time = 0.0;
    GOSDT::size = 0;
    GOSDT::iterations = 0;
    GOSDT::status = 0;

    std::vector< std::thread > workers;
    // One slot per worker; each thread writes its own iteration count into its slot.
    std::vector< int > iterations(Configuration::worker_limit);

    if(Configuration::verbose) { std::cout << "Starting Search for the Optimal Solution" << std::endl; }
    auto start = std::chrono::high_resolution_clock::now();

    optimizer.initialize();
    for (unsigned int i = 0; i < Configuration::worker_limit; ++i) {
        workers.emplace_back(work, i, std::ref(optimizer), std::ref(iterations[i]));
        #ifndef __APPLE__
        if (Configuration::worker_limit > 1) {
            // If using Ubuntu Build, we can pin each thread to a specific CPU core to improve cache locality
            cpu_set_t cpuset; CPU_ZERO(&cpuset); CPU_SET(i, &cpuset);
            int error = pthread_setaffinity_np(workers[i].native_handle(), sizeof(cpu_set_t), &cpuset);
            if (error != 0) { std::cerr << "Error calling pthread_setaffinity_np: " << error << std::endl; }
        }
        #endif
    }
    for (auto & worker : workers) { worker.join(); } // Wait for the thread pool to terminate

    auto stop = std::chrono::high_resolution_clock::now(); // Stop measuring training time
    GOSDT::time = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count() / 1000.0;
    if(Configuration::verbose) { std::cout << "Optimal Solution Search Complete" << std::endl; }

    for (int const worker_iterations : iterations) { GOSDT::iterations += worker_iterations; }
    GOSDT::size = optimizer.size();

    if (Configuration::timing != "") {
        // Appended (not truncated), so earlier content of the timing file is kept.
        std::ofstream timing_output(Configuration::timing, std::ios_base::app);
        timing_output << GOSDT::time;
        timing_output.flush();
        timing_output.close();
    }

    if(Configuration::verbose) {
        std::cout << "Training Duration: " << GOSDT::time << " seconds" << std::endl;
        std::cout << "Number of Iterations: " << GOSDT::iterations << " iterations" << std::endl;
        std::cout << "Size of Graph: " << GOSDT::size << " nodes" << std::endl;
        float lowerbound, upperbound;
        optimizer.objective_boundary(& lowerbound, & upperbound);
        std::cout << "Objective Boundary: [" << lowerbound << ", " << upperbound << "]" << std::endl;
        std::cout << "Optimality Gap: " << optimizer.uncertainty() << std::endl;
    }

    { // Model Extraction
        if (!optimizer.complete()) {
            GOSDT::status = 1;
            if (Configuration::diagnostics) {
                std::cout << "Non-convergence Detected. Beginning Diagnosis" << std::endl;
                optimizer.diagnose_non_convergence();
                std::cout << "Diagnosis complete" << std::endl;
            }
        }
        optimizer.models(models);

        if (Configuration::model_limit > 0 && models.size() == 0) {
            GOSDT::status = 1;
            if (Configuration::diagnostics) {
                std::cout << "False-convergence Detected. Beginning Diagnosis" << std::endl;
                optimizer.diagnose_false_convergence();
                std::cout << "Diagnosis complete" << std::endl;
            }
        }

        if (Configuration::verbose) {
            std::cout << "Models Generated: " << models.size() << std::endl;
            if (optimizer.uncertainty() == 0.0 && models.size() > 0) {
                std::cout << "Loss: " << models.begin() -> loss() << std::endl;
                std::cout << "Complexity: " << models.begin() -> complexity() << std::endl;
            }
        }
        if (Configuration::model != "") {
            json output = json::array();
            for (auto iterator = models.begin(); iterator != models.end(); ++iterator) {
                Model model = * iterator;
                json object = json::object();
                model.to_json(object);
                output.push_back(object);
            }
            std::string result = output.dump(2);
            if(Configuration::verbose) { std::cout << "Storing Models in: " << Configuration::model << std::endl; }
            std::ofstream out(Configuration::model);
            out << result;
            out.close();
        }
    }

    // Extraction of Rashomon Set
    if (Configuration::rashomon) {
        if (models.empty()) {
            // The bound is derived from an already extracted model. With an empty set,
            // models.begin() equals models.end() and must not be dereferenced, so report
            // the failure instead of running into undefined behaviour.
            GOSDT::status = 1;
            std::cerr << "Rashomon set extraction skipped: no model available to derive the objective bound" << std::endl;
        } else {
            auto const reference_model = models.begin();
            auto const objective = reference_model -> loss() + reference_model -> complexity();
            // The spelling of this label is kept as-is so existing log consumers keep working.
            std::cout << "OVJECTIVE: " << objective << std::endl;
            float rashomon_bound = objective;
            optimizer.reset_except_dataset();
            fit(optimizer, rashomon_bound, models);
        }
    }
}


// Rashomon-set pass over an optimizer that already holds a dataset.
//
// The run statistics (GOSDT::time, size, iterations, status) are reset, the optimizer is
// re-initialised, and the search runs again on Configuration::worker_limit threads with the
// bound  rashomon_bound * (1 + Configuration::rashomon_bound_multiplier)  and the Rashomon
// flag set. Afterwards the models are pulled into `models` and, when
// Configuration::rashomon_model is non-empty, written to that path as a JSON array.
//
// optimizer      : optimizer to run the search on
// rashomon_bound : base objective value the bound is derived from
// models         : set passed to Optimizer::models to receive the extracted models
void GOSDT::fit(Optimizer & optimizer, float rashomon_bound, std::unordered_set< Model > & models) {
    if (Configuration::verbose) {
        std::cout << "Using configuration: " << Configuration::to_string(2) << std::endl;
        std::cout << "Initializing Optimization Framework" << std::endl;
    }

    // Begin this pass with clean run statistics.
    GOSDT::time = 0.0;
    GOSDT::size = 0;
    GOSDT::iterations = 0;
    GOSDT::status = 0;

    std::vector< std::thread > pool;
    // One slot per worker; every thread reports its iteration count into its own slot.
    std::vector< int > per_worker_iterations(Configuration::worker_limit);

    if (Configuration::verbose) {
        std::cout << "Starting Extraction of Rashomon Set" << std::endl;
    }
    auto search_begin = std::chrono::high_resolution_clock::now();

    optimizer.initialize();
    // Widen the base objective by the configured relative margin.
    rashomon_bound += rashomon_bound * Configuration::rashomon_bound_multiplier;
    std::cout << "Rashomon bound: " << rashomon_bound << std::endl;
    optimizer.set_rashomon_bound(rashomon_bound);
    optimizer.set_rashomon_flag();

    for (unsigned int worker_id = 0; worker_id < Configuration::worker_limit; ++worker_id) {
        pool.emplace_back(work, worker_id, std::ref(optimizer), std::ref(per_worker_iterations[worker_id]));
        #ifndef __APPLE__
        if (Configuration::worker_limit > 1) {
            // On non-Apple builds each thread is pinned to the core matching its index
            // to improve cache locality.
            cpu_set_t affinity;
            CPU_ZERO(&affinity);
            CPU_SET(worker_id, &affinity);
            int error = pthread_setaffinity_np(pool.back().native_handle(), sizeof(cpu_set_t), &affinity);
            if (error != 0) {
                std::cerr << "Error calling pthread_setaffinity_np: " << error << std::endl;
            }
        }
        #endif
    }

    // Block until every worker has finished.
    for (std::thread & worker : pool) {
        worker.join();
    }

    auto search_end = std::chrono::high_resolution_clock::now();
    GOSDT::time = std::chrono::duration_cast<std::chrono::milliseconds>(search_end - search_begin).count() / 1000.0;
    if (Configuration::verbose) {
        std::cout << "Rashomon Set Construction Completed" << std::endl;
    }

    for (int const count : per_worker_iterations) {
        GOSDT::iterations += count;
    }
    GOSDT::size = optimizer.size();

    if (Configuration::timing != "") {
        // Opened in append mode: the duration is added after any existing file content.
        std::ofstream timing_output(Configuration::timing, std::ios_base::app);
        timing_output << GOSDT::time;
        timing_output.flush();
        timing_output.close();
    }

    if (Configuration::verbose) {
        std::cout << "Training Duration: " << GOSDT::time << " seconds" << std::endl;
        std::cout << "Number of Iterations: " << GOSDT::iterations << " iterations" << std::endl;
        std::cout << "Size of Graph: " << GOSDT::size << " nodes" << std::endl;
        float lowerbound, upperbound;
        optimizer.objective_boundary(& lowerbound, & upperbound);
        std::cout << "Objective Boundary: [" << lowerbound << ", " << upperbound << "]" << std::endl;
        std::cout << "Optimality Gap: " << optimizer.uncertainty() << std::endl;
    }

    // Model extraction is timed separately from the search itself.
    auto extraction_begin = std::chrono::high_resolution_clock::now();
    optimizer.models(models);
    auto extraction_end = std::chrono::high_resolution_clock::now();
    if (Configuration::verbose) {
        float extraction_seconds = std::chrono::duration_cast<std::chrono::milliseconds>(extraction_end - extraction_begin).count() / 1000.0;
        std::cout << "Extraction Duration: " << extraction_seconds << " seconds" << std::endl;
    }

    // Models were requested but none came back: flag the run and optionally diagnose.
    if (Configuration::model_limit > 0 && models.empty()) {
        GOSDT::status = 1;
        if (Configuration::diagnostics) {
            std::cout << "False-convergence Detected. Beginning Diagnosis" << std::endl;
            optimizer.diagnose_false_convergence();
            std::cout << "Diagnosis complete" << std::endl;
        }
    }

    if (Configuration::verbose) {
        std::cout << "Size of Rashomon set: " << models.size() << std::endl;
        if (optimizer.uncertainty() == 0.0 && !models.empty()) {
            std::cout << "Loss: " << models.begin() -> loss() << std::endl;
            std::cout << "Complexity: " << models.begin() -> complexity() << std::endl;
        }
    }

    if (Configuration::rashomon_model != "") {
        json output = json::array();
        // Each element is copied so that to_json runs on a local Model object.
        for (Model model : models) {
            json object = json::object();
            model.to_json(object);
            output.push_back(object);
        }
        std::string result = output.dump(2);
        if (Configuration::verbose) {
            std::cout << "Storing Models in: " << Configuration::rashomon_model << std::endl;
        }
        std::ofstream out(Configuration::rashomon_model);
        out << result;
        out.close();
    }

    // Disabled: trie-based export of the model set.
    // if (Configuration::rashomon_trie != "") {
    //     bool calculate_size = false;
    //     char const *type = "node";
    //     Trie* tree = new Trie(calculate_size, type);
    //     tree->insert_root();
    //     for (auto iterator = models.begin(); iterator != models.end(); ++iterator) {
    //         tree->insert_model(&(*iterator));
    //     }
    //
    //     std::string serialization;
    //     tree->serialize(serialization, 2);
    //
    //     if(Configuration::verbose) { std::cout << "Storing Models in: " << Configuration::rashomon_trie << std::endl; }
    //     std::ofstream out(Configuration::rashomon_trie);
    //     out << serialization;
    //     out.close();
    // }
}


// Worker loop: calls optimizer.iterate(id) until it returns false and reports how many
// calls returned true.
//
// id               : worker index forwarded to Optimizer::iterate
// optimizer        : optimizer shared by all workers
// return_reference : receives the iteration count; left untouched if an exception escapes
//
// On an IntegrityViolation the status flag is set to 1, the violation is printed and the
// same exception is re-raised.
// NOTE: fit() runs this function as a std::thread entry point; an exception that leaves a
// thread's entry function ends the process through std::terminate.
void GOSDT::work(int const id, Optimizer & optimizer, int & return_reference) {
    unsigned int iterations = 0;
    try {
        while (optimizer.iterate(id)) { iterations += 1; }
    } catch (IntegrityViolation & exception) {
        // Caught by reference so the exception object is neither copied nor sliced.
        GOSDT::status = 1;
        std::cout << exception.to_string() << std::endl;
        throw; // re-raise the original exception object rather than a moved-from copy
    }
    return_reference = iterations;
}